# -*- coding: utf-8 -*-
import functools
import inspect
import numbers

import jams
import numpy as np
import os
import os.path as osp
import shutil
import pprint

from .logger import create_logger, DesedError


def _check_random_state(seed):
    """Turn seed into a np.random.RandomState instance

    Parameters
    ----------
    seed : None | np.random | int | instance of RandomState
        If seed is None (or the ``np.random`` module itself), return the RandomState
        singleton used by np.random.
        If seed is an integer (Python or numpy), return a new RandomState instance seeded with seed.
        If seed is already a RandomState instance, return it.
        Otherwise raise ValueError.
    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    # numpy integer scalars are registered as numbers.Integral; np.integer is kept for clarity.
    if isinstance(seed, (numbers.Integral, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    # Wrap seed in a 1-tuple: a tuple seed would otherwise be unpacked as several
    # %-format arguments and raise TypeError instead of the documented ValueError.
    raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
                     ' instance' % (seed,))


def create_folder(folder, exist_ok=True, delete_if_exists=False):
    """ Create folder (and parent folders) if not exists.

    Args:
        folder: str, path of folder(s) to create.
        exist_ok: bool, forwarded to ``os.makedirs``; if False, raise ``FileExistsError`` when
            ``folder`` already exists (after the optional deletion below).
        delete_if_exists: bool, True to delete ``folder`` and all its content first when it
            exists, so that it is recreated empty.

    Returns:
        None
    """
    if delete_if_exists and os.path.exists(folder):
        # Only remove the tree: os.makedirs below recreates it. Re-creating it here as well
        # made the following os.makedirs(..., exist_ok=False) always raise FileExistsError.
        shutil.rmtree(folder)

    os.makedirs(folder, exist_ok=exist_ok)


pp = pprint.PrettyPrinter()


def pprint(x):
    """Pretty-print ``x`` using the module-level ``pp`` printer.

    Note: once defined, this function shadows the standard ``pprint`` module name in this file.
    """
    printer = pp
    printer.pprint(x)


def choose_cooccurence_class(co_occur_params, random_state=None):
    """ Choose another class given a dictionary of parameters (from an already specified class).
    Args:
        co_occur_params: dict, define the parameters of co-occurrence of classes
            Example of co_occur_params dictionary::
                {
                  "max_events": 13,
                  "classes": [
                    "Alarm_bell_ringing",
                    "Dog",
                  ],
                  "probas": [
                    70,
                    30
                  ]
                }
            "classes" and "probas" map each other by position. "probas" are relative weights:
            they are normalized by their sum, so they may sum to 1, to 100, or to anything > 0.
        random_state: None, int, or RandomState object. If None, the global numpy random
            generator is used.
    Returns:
        str, the class name.
    Raises:
        ValueError: if "probas" is empty or does not sum to a positive value.
    """
    probas = co_occur_params['probas']
    cumulative = np.cumsum(np.asarray(probas, dtype=float))
    if cumulative.size == 0 or cumulative[-1] <= 0:
        raise ValueError(f"'probas' must contain at least one positive value, got: {probas}")

    if random_state is not None:
        random_state = _check_random_state(random_state)
        random_val = random_state.rand()
    else:
        # Get a random value between 0-1
        random_val = np.random.uniform()

    # Scale the draw to the total weight so that probas not summing to 1 (e.g. percentages)
    # are respected; when they already sum to 1 this is a no-op.
    threshold = random_val * cumulative[-1]
    # Index of the first accumulated weight strictly greater than the threshold
    # (a class with zero weight can therefore never be chosen).
    idx_chosen_class = int(np.argmax(cumulative > threshold))
    return co_occur_params['classes'][idx_chosen_class]


def change_snr(jams_path, db_change):
    """ Shift the SNR of every foreground event of a JAMS generated by Scaper.

    Only observations whose value has role "foreground" are changed; the file on disk is not
    written.

    Args:
        jams_path: str, path of the JAMS file created by Scaper
        db_change: float, value in dB added to each foreground event's "snr" (can be negative)

    Returns:
        jams object that has been modified
    """
    modified = jams.load(jams_path)
    scaper_ann = modified.annotations.search(namespace='scaper')[0]
    for observation in scaper_ann.data:
        if observation.value["role"] != "foreground":
            continue
        # Edit the value dict in place (note that the resulting snr can be negative)
        observation.value["snr"] = observation.value["snr"] + db_change
    return modified


def modify_fg_onset(jams_path, slice_seconds):
    """ Shift the onset of every foreground event of a JAMS generated by Scaper.

    For each observation with role "foreground", both the Scaper "event_time" stored in its value
    and the observation time are set to ``event_time + slice_seconds``. The file on disk is not
    written.

    Args:
        jams_path: str, path of the JAMS file (created by Scaper) to load
        slice_seconds: float, value in seconds added to each foreground onset (can be negative)

    Returns:
        jams object that has been modified

    Raises:
        DesedError: if a shifted onset is greater than the annotation duration or negative.
            All events are checked before anything is modified.
    """
    jams_obj = jams.load(jams_path)
    ann = jams_obj.annotations.search(namespace='scaper')[0]
    data = ann.data
    # Snapshot foreground observations with their positions before touching ``data``: deleting
    # and adding observations while enumerating the same container shifts indices, which can skip
    # events or visit a re-inserted one again when there are several foreground events.
    foreground = [(cnt, obs) for cnt, obs in enumerate(data) if obs.value["role"] == "foreground"]

    # Checking every new onset is possible before modifying anything
    for _, obs in foreground:
        new_onset = obs.value["event_time"] + slice_seconds
        if new_onset > ann.duration:
            raise DesedError(f"The new onset is not valid: {new_onset} > {ann.duration}, "
                             f"for file: {jams_path}")
        if new_onset < 0:
            raise DesedError(f"The new onset is not valid: {new_onset} < 0, "
                             f"for file: {jams_path}")

    new_observations = []
    for _, obs in foreground:
        new_onset = obs.value["event_time"] + slice_seconds
        # Change source time by adding the specified value
        obs.value["event_time"] = new_onset
        # Observations are namedtuples, so the time is changed by building a replacement
        # (sharing the already-updated value dict).
        # Todo: find a better way to do that with Scaper
        new_observations.append(obs._replace(time=new_onset))

    # Delete from the highest index down so the remaining snapshot indices stay valid.
    for cnt, _ in reversed(foreground):
        del data[cnt]
    for obs in new_observations:
        data.add(obs)

    return jams_obj


def modify_jams(list_jams, modify_function, out_dir=None, **kwargs):
    """ Apply a modification function to JAMS files and save the results.

    Args:
        list_jams: list, the list of jams_path to be modified
        modify_function: function, a function that takes (jams_path, **kwargs) as input, and
            returns the modified jams object (jams.JAMS)
        out_dir: str, the path of the directory of the new jams (created if needed; files keep
            their base name). If None, overwrite each jam_file in place.
        **kwargs: arguments that'll be given to modify_function

    Returns:
        list, the paths of the saved jams files, in the same order as list_jams.

    Examples:
        There are two ways of using this function, (example with snr diminution):
        >>> jams_to_modify = ["material/5.jams"]
        1-
        >>> modify_jams(jams_to_modify, change_snr, db_change=-6)
        2- Recommended
        >>> decrease_snr = functools.partial(change_snr, db_change=-6)
        >>> modify_jams(jams_to_modify, decrease_snr)
    """
    logger = create_logger(__name__ + "/" + inspect.currentframe().f_code.co_name)
    if out_dir is not None:
        # Only create the output folder when one is given: os.makedirs(None) raises TypeError,
        # which made the in-place mode (out_dir=None) unusable.
        create_folder(out_dir)

    new_list_jams = []
    for jam_file in list_jams:
        logger.debug(jam_file)
        jam_obj = modify_function(jam_file, **kwargs)

        if out_dir is not None:
            out_jams = osp.join(out_dir, osp.basename(jam_file))
        else:
            out_jams = jam_file
        jam_obj.save(out_jams)
        new_list_jams.append(out_jams)

    return new_list_jams

